韩广云,NVIDIA GPU 加速计算专家团队 高级工程师 | AI Open Day/2025-11-07
线性注意力 (Linear Attention)
扩展至 Delta Rule (Extension to Delta Rule)
研究活跃
工业级大语言模型正在涌现
关注于核函数级别的优化及其挑战
最基本的形式
可以重写为块形式 (block-wise form)
在 MiniMax 的工作中¹
展开为块级形式 (Unroll for chunk-wise form)
调度 (Scheduling)
(seqlen_q / TileSizeQ) x (seqlen_k / TileSizeK) x num_heads x num_seqs(seqlen_q / TileSizeQ) 和 num_heads x num_seqs 部分是并行的 (parallelize),而 (seqlen_k / TileSizeK) 部分是串行的 (sequential)。外层循环处理 Q tiles
内层循环处理 K 或 V tiles
调度 (Scheduling)
(seqlen / TileSize) x num_heads x num_seqs(seqlen / TileSize) 部分是串行的 (sequential),而 num_heads x num_seqs 部分是并行的 (parallelize)。只有一个串行循环
每个 CTA 产生
假设 B 是 Tile Size, N 是序列长度
2B²d x (N/B - 1)(2B²d + 2B d²) x N/B总近似计算量
O(Nd²) 或 O(N B d)分析结论
上图展示了一个理想的指令顺序,旨在实现高效的流水线操作,包括数据加载(Load Q, K, V)、数学计算(Math WG1, WG2)和结果存储(Store O)。图中展示了 Acquire/Release 各种 pipe(Q, K, V, O pipe)以协调数据流,最终计算出完整的 tile。
中间状态分布在数学工作组寄存器 (Math WG Registers) 中
最大 TileSize 仅为 64
WGMMA 的排序由最小化每个缓冲区的生命周期决定
WGMMA 配置
硬件映射要点
S(状态)存在于寄存器中。
所有设计都是将 S 放入寄存器的结果。
基准测试环境
nvidia-smi -lgc 1000,1000 锁定 GPU 时钟。通过 CUDA Events 测量
完全融合的线性注意力 (Fully fused LA) 可以更快
vllm 的速度较慢
-因为它在 kernel 外部使用一个循环来处理可变长度的输入。
基于调度
(seqlen / TileSize) x num_heads x num_seqs(seqlen / TileSize),并行部分 num_heads x num_seqs可以实现的并行度远小于 FMHA (Flash Multi-Head Attention)
可能导致硬件利用率低,为什么?
num_seqs,也称为 batch size。以四个计算单元为例:
尚未实现。
仅关注分块形式 (chunk-wise form)
在 Linear Attention 中替换 V
元素级处理(Elementwise processing)是不同的。
计算流程
新特性:
KK和其逆矩阵的计算。
计算NewV。
但是现在我们有很多 exp2f 和 log2f 的计算。
下图展示了在Hopper GPU上实现全融合Delta法则的时间线图。该图解了不同工作组(Math WG1, Math WG2)中各种计算任务(如加载Q/K/V,计算T=KK,O1=QS,V-SK等)的并行与依赖关系。
直接前向代换很简单。
但是非常昂贵!
在对角线上使用小块矩阵进行计算。
分块求逆来解决问题。
内存高效。
计算高效。
下图展示了分块求逆的过程,其中绿色部分代表原始矩阵,浅绿色部分代表求逆后的矩阵。
将T和P的计算移至独立的流水线。
之前:
当前:
提升了性能。
下图展示了3阶段流水线的执行流程,对比了不同工作组(Aux Math WG1, Math WG2, and WG3)的任务调度。
基准测试环境
GPU最高时钟频率:1785 MHz (1.78 GHz)
nvidia-smi -lgc 1000,1000内存时钟频率:2619 Mhz.
通过CUDA Events测量
全融合Delta法则可以快很多
下表展示了在固定序列长度和固定批大小两种情况下的基准测试结果,对比了fla和我们(Ours)的实现。
<font size="1">1. For technical discussion and reference only, perf. may vary based on different product portfolio.</font>
<font size="1">2. Flash attention performance is tested with version v2.5.3 commit 49b3c3b</font>
<font size="1">3. Our kernel is still in development</font>
主循环步骤中的长延迟
3个以上的WGMMA
矩阵求逆
由于门控delta法则,存在过多的逐元素处理。
精度问题
并行性问题